# coding=utf-8

"""
    @header mypredict.py
    @abstract   Real-time monocular depth prediction from a camera feed using PackNet-SfM

    @MyBlog: http://www.kuture.com.cn
    @author  Created by Kuture on 2021/8/9
    @version 1.0.0 2021/8/9 Creation
    Copyright © 2021年 Mr.Li All rights reserved
"""
import argparse

import cv2
import numpy as np
import torch
from PIL import Image
from packnet_sfm.models.model_wrapper import ModelWrapper
from packnet_sfm.datasets.augmentations import resize_image, to_tensor
from packnet_sfm.utils.horovod import hvd_init, rank
from packnet_sfm.utils.config import parse_test_file
from packnet_sfm.utils.depth import viz_inv_depth


class DepthPredictProcessor(object):
    """Run a PackNet-SfM depth model on single frames or a video source.

    The checkpoint is loaded once at construction time. ``predict`` turns one
    frame into a colour-mapped inverse-depth image, and ``camera_display``
    shows those predictions for a capture source in an OpenCV window.
    """

    def __init__(self, model_file, img_half=True, image_shape=(160, 320)):
        """Load the checkpoint and prepare the model for inference.

        Args:
            model_file: Checkpoint path handed to ``parse_test_file``.
            img_half: When true, the model and the input tensors are cast to
                ``torch.float16`` on the GPU. No cast is performed when CUDA
                is unavailable.
            image_shape: Size passed to ``resize_image`` for every frame
                (presumably ``(height, width)`` -- confirm against
                ``resize_image``). Sizes tried so far: 192x640, 192x160,
                160x320, 160x160.
        """
        # Initialise Horovod; rank() below selects the GPU index.
        hvd_init()
        config, state_dict = parse_test_file(model_file)
        model_wrapper = ModelWrapper(config, load_datasets=False)
        model_wrapper.load_state_dict(state_dict)
        self.dtype = torch.float16 if img_half else None

        # Pick the device once so predict() does not re-query CUDA per frame.
        # ``None`` means "leave everything on the CPU".
        self.device = (
            'cuda:{}'.format(rank()) if torch.cuda.is_available() else None
        )
        if self.device is not None:
            model_wrapper = model_wrapper.to(self.device, dtype=self.dtype)

        model_wrapper.eval()
        self.model_wrapper = model_wrapper
        # NOTE: config.datasets.augmentation.image_shape from the checkpoint
        # is not used here; pass it as ``image_shape`` if that size is wanted.
        self.image_shape = tuple(image_shape)

    def predict(self, image):
        """Predict depth for one frame and return a displayable image.

        Args:
            image: Array accepted by ``PIL.Image.fromarray``.
                ``camera_display`` passes RGB frames.

        Returns:
            A ``numpy.uint8`` array: the output of ``viz_inv_depth`` scaled by
            255 with its last axis reversed.
        """
        frame = Image.fromarray(image)

        # Resize to the configured input size and add a batch dimension.
        frame = resize_image(frame, self.image_shape)
        batch = to_tensor(frame).unsqueeze(0)

        if self.device is not None:
            batch = batch.to(self.device, dtype=self.dtype)

        # Inference only: no_grad avoids building an autograd graph per frame.
        with torch.no_grad():
            pred_inv_depth = self.model_wrapper.depth(batch)['inv_depths'][0]

        # Colour-map the first prediction of the batch and scale it by 255
        # (presumably viz_inv_depth returns values in [0, 1] -- TODO confirm).
        viz_pred_inv_depth = viz_inv_depth(pred_inv_depth[0]) * 255

        # Reverse the channel axis for OpenCV display (assumes the colour map
        # is RGB and cv2.imshow expects BGR).
        return np.uint8(viz_pred_inv_depth[:, :, ::-1])

    def camera_display(self, source=0):
        """Show depth predictions for a capture source until ``q`` is pressed.

        Args:
            source: Anything ``cv2.VideoCapture`` accepts; defaults to ``0``.

        Raises:
            RuntimeError: If the capture source cannot be opened.
        """
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            # Without this check the loop below would spin forever with no
            # window to receive the quit key.
            cap.release()
            raise RuntimeError(
                'Unable to open video source: {!r}'.format(source)
            )

        try:
            # Predictions are resized to the capture's frame size for display.
            display_size = (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )

            while True:
                ret_val, frame = cap.read()
                # Failed reads are skipped rather than treated as fatal.
                if ret_val:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    res_img = self.predict(frame)
                    cv2.imshow('', cv2.resize(res_img, display_size))

                if cv2.waitKeyEx(1) == ord('q'):
                    break
        finally:
            # Always free the capture device and close the window, including
            # when predict() raises or the user interrupts the loop.
            cap.release()
            cv2.destroyAllWindows()


if __name__ == '__main__':

    # Checkpoint to load; the path is relative to the working directory.
    checkpoint_path = './Data/PackNet01_MR_velsup_CStoK.ckpt'

    # To try a single still image instead of the live feed, load it as an
    # array (e.g. with cv2.imread) and hand it to ``predict``.
    processor = DepthPredictProcessor(checkpoint_path)
    processor.camera_display()






































